/*
 * Copyright 2020 LinkedIn Corp.
 * Licensed under the BSD 2-Clause License (the "License").
 * See License in the project root for license information.
 */

package com.linkedin.avroutil1.compatibility;

import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericContainer;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericEnumSymbol;
import org.apache.avro.generic.GenericFixed;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.specific.SpecificData;
import org.apache.avro.specific.SpecificFixed;
import org.apache.avro.specific.SpecificRecordBase;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.stream.Collectors;


/**
 * Static helper methods for inspecting, traversing and resolving avro {@link Schema}s,
 * and for relating avro data objects (specific or generic) to their schemas.
 */
public class AvroSchemaUtil {
  private static final Logger LOG = LoggerFactory.getLogger(AvroSchemaUtil.class);
  //reader types a writer INT is accepted as when promotions are enabled
  private static final List<Schema.Type> INT_PROMOTIONS = Collections.unmodifiableList(Arrays.asList(
          Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE
  ));
  //reader types a writer LONG is accepted as when promotions are enabled
  private static final List<Schema.Type> LONG_PROMOTIONS = Collections.unmodifiableList(Arrays.asList(
          Schema.Type.FLOAT, Schema.Type.DOUBLE
  ));
  //marker cached for classes that declare no schema. compared by identity, never returned to callers
  private static final Schema NONE = Schema.create(Schema.Type.NULL);
  //per-class cache of the schema declared on a (generated) class, or NONE if nothing was found
  private static final ClassValue<Schema> DECLARED_SCHEMAS = new ClassValue<Schema>() {
    @Override
    protected Schema computeValue(Class<?> type) {
      Schema result = null;

      //take 1 - look for static method public Schema getClassSchema() (exists in "our" output classes
      //and classes generated by vanilla avro 1.7+)
      try {
        Method getClassSchema = type.getMethod("getClassSchema");
        result = (Schema) getClassSchema.invoke(null);
      } catch (NoSuchMethodException expected) {
        //nope
        LOG.debug("could not find {}.getClassSchema()", type.getName(), expected);
      } catch (Exception | NoClassDefFoundError unexpected) {
        //NoClassDefFoundError can happen when trying to operate on a class generated by
        //vanilla avro under a different avro version. the class can "load" just fine and throw
        //NoClassDefFoundError over org/apache/avro/data/RecordBuilder for example only at this point.
        //we consider this and any other unexpected exception an error
        LOG.error("while looking up or invoking {}.getClassSchema()", type.getName(), unexpected);
      }
      if (result != null) {
        return result;
      }

      //take 2 - look directly for public static final org.apache.avro.Schema SCHEMA$ field
      //(exists in "our" output classes and classes generated by vanilla avro 1.5+)
      try {
        Field schema$ = type.getDeclaredField("SCHEMA$");
        result = (Schema) schema$.get(null);
      } catch (NoSuchFieldException expected) {
        LOG.debug("could not find {}.SCHEMA$", type.getName(), expected);
      } catch (Exception unexpected) {
        LOG.error("while looking up or accessing {}.SCHEMA$", type.getName(), unexpected);
      }
      if (result != null) {
        return result;
      }

      //ClassValue would cache a null just fine, but a marker keeps "nothing found" explicit
      return NONE;
    }
  };

  private AvroSchemaUtil() {
    //util class
  }

  /**
   * walks the given schema and everything reachable from it (union branches, array elements,
   * map values and record fields), invoking the visitor along the way.
   * every schema object is visited at most once (by identity), so recursive schemas terminate.
   * @param schema schema to start from. required.
   * @param visitor visitor to invoke on every schema and record field encountered. required.
   * @throws IllegalArgumentException if either argument is null
   */
  public static void traverseSchema(Schema schema, SchemaVisitor visitor) {
    if (schema == null || visitor == null) {
      throw new IllegalArgumentException("arguments must not be null");
    }
    IdentityHashMap<Object, Boolean> visited = new IdentityHashMap<>();
    traverseSchema(schema, visitor, visited);
  }

  /**
   * Returns true if a null value is allowed as the default value for a field
   * (given its schema). It is valid if and only if:
   * (1) The field's type is null, or
   * (2) The field is a union, where the first alternative type is null.
   * @param schema schema of the field, possibly null (in which case the result is false)
   * @return true if null is an acceptable default value under the given schema
   */
  public static boolean isNullAValidDefaultForSchema(Schema schema) {
    if (schema == null) {
      return false;
    }
    Schema.Type type = schema.getType();
    if (type == Schema.Type.NULL) {
      return true;
    }
    if (type != Schema.Type.UNION) {
      return false;
    }
    List<Schema> branches = schema.getTypes();
    return !branches.isEmpty() && branches.get(0).getType() == Schema.Type.NULL;
  }

  /**
   * returns true if the given value is a valid "instance" of the given schema
   * @param value a value, possibly null
   * @param schema a schema to check vs the value. required.
   * @return true if the value is an instance of the schema
   * @throws IllegalArgumentException if schema is null
   */
  public static boolean isValidValueForSchema(Object value, Schema schema) {
    if (schema == null) {
      throw new IllegalArgumentException("schema required");
    }
    if (value == null) {
      //NOTHING in avro is nullable except type NULL and unions (which might have type NULL as a branch)
      Schema.Type schemaType = schema.getType();
      if (schemaType == Schema.Type.NULL) {
        return true;
      }
      if (schemaType == Schema.Type.UNION) {
        for (Schema branch : schema.getTypes()) {
          if (isValidValueForSchema(null, branch)) {
            return true;
          }
        }
      }
      return false;
    }

    //both validate() implementations handle unions.
    //values we cant classify (isSpecific() returns null) get treated as generic
    if (Boolean.TRUE.equals(isSpecific(value))) {
      return SpecificData.get().validate(schema, value);
    }
    return GenericData.get().validate(schema, value);
  }

  /**
   * given a (parent) schema, and a field name, find the schema for that field.
   * if the field is a union, returns the (only) non-null branch of the union
   * @param parent parent schema containing field
   * @param fieldName name of the field in question
   * @return schema of the field (or non-null union branch thereof), null if parent has no such field
   * @throws IllegalArgumentException if any argument is null/empty, or if the field is a union
   * that does not have exactly one non-null branch
   */
  public static Schema findNonNullUnionBranch(Schema parent, String fieldName) {
    if (parent == null || fieldName == null || fieldName.isEmpty()) {
      throw new IllegalArgumentException("arguments must not be null/empty");
    }
    Schema.Field field = parent.getField(fieldName);
    return field == null ? null : findNonNullUnionBranch(field.schema());
  }

  /**
   * Given a union schema with exactly one non-null branch, return that non-null branch.
   * If the schema is not a union, return it as is.
   * @param schema a union schema containing exactly one non-null branch, or a non-union schema.
   * @return the non-null union branch, or the original schema.
   * @throws IllegalArgumentException if schema is null, or is a union that does not have
   * exactly one non-null branch
   */
  public static Schema findNonNullUnionBranch(Schema schema) {
    if (schema == null) {
      throw new IllegalArgumentException("schema must not be null");
    }
    if (schema.getType() != Schema.Type.UNION) {
      return schema;  // schema is not a union.
    }
    List<Schema> nonNullBranches = schema.getTypes().stream()
        .filter(branch -> branch.getType() != Schema.Type.NULL)
        .collect(Collectors.toList());
    if (nonNullBranches.size() != 1) {
      throw new IllegalArgumentException(String.format("schema has %d non-null union branches, where exactly 1 is expected",
          nonNullBranches.size()));
    }
    return nonNullBranches.get(0);
  }

  /**
   * given a root schema (which may contain more named schemas defined inline)
   * returns the set of all named schemas defined by the root schema (including the
   * root schema itself, if it is a named schema) keyed by their full name
   * @param root root schema
   * @return map of all schemas defined inside root, possibly including root
   * @throws IllegalArgumentException if root is null
   */
  public static Map<String, Schema> getAllDefinedSchemas(Schema root) {
    if (root == null) {
      throw new IllegalArgumentException("argument must not be null");
    }
    final Map<String, Schema> results = new HashMap<>(3);
    //anonymous class (and not a lambda) because only visitSchema() is of interest here
    SchemaVisitor visitor = new SchemaVisitor() {
      @Override
      public void visitSchema(Schema schema) {
        if (HelperConsts.NAMED_TYPES.contains(schema.getType())) {
          results.put(schema.getFullName(), schema);
        }
      }
    };
    AvroSchemaUtil.traverseSchema(root, visitor);
    return results;
  }

  /**
   * given a set of root names schemas (which may contain more named schemas defined inline)
   * returns the set of all named schemas defined by the root schemas (including the
   * root schemas themselves) keyed by their full name.
   * if the same full name is defined under more than one root, the definition
   * encountered last (in iteration order of the collection) is the one returned.
   * @param roots root named schemas
   * @return map of all schemas defined inside root, including root
   * @throws IllegalArgumentException if roots (or any schema in it) is null
   */
  public static Map<String, Schema> getAllDefinedSchemas(Collection<Schema> roots) {
    if (roots == null) {
      throw new IllegalArgumentException("argument must not be null");
    }
    Map<String, Schema> results = new HashMap<>(roots.size());
    for (Schema root : roots) {
      results.putAll(getAllDefinedSchemas(root));
    }
    return results;
  }

  /**
   * implements "schema resolution", as outlined in the relevant section of (recent-ish) the avro specification.
   * <p>
   * when either schema is a union, union branches are always matched strictly first (no aliases, no promotions)
   * and then, if nothing matched, loosely (aliases and promotions both allowed) - regardless of the
   * values of useAliases and usePromotions, which only govern the resolution of non-union schemas.
   * @param writer schema the data was written with. required.
   * @param reader schema the data is to be read with. required.
   * @param useAliases true to allow a named reader schema to match a writer schema by one of the reader's aliases
   * @param usePromotions true to allow type promotions (int to long/float/double, long to float/double,
   *                      float to double, string to/from bytes)
   * @return the resolution result (holding the matched reader and writer schemas), or null if
   * the reader schema cannot be resolved against the writer schema
   * @throws IllegalArgumentException if writer or reader is null
   */
  public static SchemaResolutionResult resolveReaderVsWriter(Schema writer, Schema reader, boolean useAliases, boolean usePromotions) {
    if (writer == null || reader == null) {
      throw new IllegalArgumentException("writer and reader schemas are required");
    }
    Schema.Type writerType = writer.getType();
    Schema.Type readerType = reader.getType();

    //unions are handled separately and never fall through to the switch below. a union that
    //fails to match yields null (used to be an IllegalStateException out of the switch for
    //writer unions, which also broke union-vs-union resolution on the 1st non-matching reader branch)
    if (readerType == Schema.Type.UNION) {
      //covers both writer-is-union and writer-is-not-union
      return resolveVsReaderUnion(writer, reader);
    }
    if (writerType == Schema.Type.UNION) {
      return resolveWriterUnionVsReader(writer, reader);
    }

    //no unions

    switch (writerType) {
      //primitives that are promotable (others are handled by default branch):
      case INT:
        return resolvePrimitive(writer, reader, usePromotions && INT_PROMOTIONS.contains(readerType));
      case LONG:
        return resolvePrimitive(writer, reader, usePromotions && LONG_PROMOTIONS.contains(readerType));
      case FLOAT:
        return resolvePrimitive(writer, reader, usePromotions && readerType == Schema.Type.DOUBLE);
      case STRING:
        return resolvePrimitive(writer, reader, usePromotions && readerType == Schema.Type.BYTES);
      case BYTES:
        return resolvePrimitive(writer, reader, usePromotions && readerType == Schema.Type.STRING);
      //named types:
      case ENUM:
      case FIXED:
      case RECORD:
        if (readerType != writerType) {
          return null;
        }
        //fixed types need to match on size
        if (writerType == Schema.Type.FIXED && writer.getFixedSize() != reader.getFixedSize()) {
          return null;
        }
        //named types need to match on fullname, or possibly aliases
        String writerFullName = writer.getFullName();
        if (writerFullName.equals(reader.getFullName())) {
          return new SchemaResolutionResult(reader, writer, false);
        }
        if (useAliases) {
          Set<String> aliases = reader.getAliases();
          //TODO - handle "relative" aliases (which are not fullnames) as per spec
          if (aliases != null && aliases.contains(writerFullName)) {
            return new SchemaResolutionResult(reader, writer, false, true);
          }
        }
        return null;
      //collections:
      case ARRAY:
        if (readerType != writerType) {
          return null;
        }
        //arrays need to match on element type
        if (resolveReaderVsWriter(writer.getElementType(), reader.getElementType(), useAliases, usePromotions) == null) {
          return null;
        }
        return new SchemaResolutionResult(reader, writer, false);
      case MAP:
        if (readerType != writerType) {
          return null;
        }
        //maps need to match on value type
        if (resolveReaderVsWriter(writer.getValueType(), reader.getValueType(), useAliases, usePromotions) == null) {
          return null;
        }
        return new SchemaResolutionResult(reader, writer, false);
      //unions were fully handled above, so getting here means a logic error in this method
      case UNION:
        throw new IllegalStateException("unexpected:"  + writer);
      //rest of the primitive types:
      default:
        return resolvePrimitive(writer, reader, false);
    }
  }

  /**
   * resolves a writer schema (union or not) against a reader union.
   * (quoting the spec) The first schema in the reader's union that matches the writer's schema
   * is recursively resolved against it. If none match, an error is signalled (null returned here).
   * @param writer writer schema, possibly a union itself
   * @param readerUnion reader schema, must be a union
   * @return result for the 1st reader branch that matches, strict matches taking precedence. null if none
   */
  private static SchemaResolutionResult resolveVsReaderUnion(Schema writer, Schema readerUnion) {
    List<Schema> readerBranches = readerUnion.getTypes();
    //1st pass is strict, 2nd pass allows promotions and aliases
    for (boolean loose : new boolean[] {false, true}) {
      for (Schema readerBranch : readerBranches) {
        SchemaResolutionResult match = resolveReaderVsWriter(writer, readerBranch, loose, loose);
        //TODO - assert single match (for loose matches)?
        if (match != null) {
          //NOTE(review): whatever the nested match says about promotion/alias use is not
          //carried over into this result (literal false is passed) - confirm this is intended
          return new SchemaResolutionResult(readerBranch, match.getWriterMatch(), false);
        }
      }
    }
    return null;
  }

  /**
   * resolves a writer union against a non-union reader schema.
   * (quoting the spec) If the reader's schema matches the selected writer's schema, it is
   * recursively resolved against it. If they do not match, an error is signalled (null returned here).
   * @param writerUnion writer schema, must be a union
   * @param reader reader schema, must not be a union
   * @return result for the 1st writer branch the reader matches, strict matches taking precedence. null if none
   */
  private static SchemaResolutionResult resolveWriterUnionVsReader(Schema writerUnion, Schema reader) {
    List<Schema> writerBranches = writerUnion.getTypes();
    //1st pass is strict, 2nd pass allows promotions and aliases
    for (boolean loose : new boolean[] {false, true}) {
      for (Schema writerBranch : writerBranches) {
        SchemaResolutionResult match = resolveReaderVsWriter(writerBranch, reader, loose, loose);
        if (match != null) {
          //NOTE(review): same as for reader unions - promotion/alias use is not propagated
          return new SchemaResolutionResult(reader, match.getWriterMatch(), false);
        }
      }
    }
    return null;
  }

  /**
   * resolves a pair of non-union schemas where the writer is of primitive type
   * @param writer writer schema
   * @param reader reader schema
   * @param promotable true if a promotion from writer type to reader type exists and is allowed
   * @return a (non promoted) result if the types are the same, a promoted result if they differ
   * but promotable is true, null otherwise
   */
  private static SchemaResolutionResult resolvePrimitive(Schema writer, Schema reader, boolean promotable) {
    if (reader.getType() == writer.getType()) {
      return new SchemaResolutionResult(reader, writer, false);
    }
    if (promotable) {
      return new SchemaResolutionResult(reader, writer, true);
    }
    return null;
  }

  /**
   * returns the avro schema as specified on a (generated) java class.
   * avro creates SCHEMA$ fields and static getClassSchema() methods
   * on generated classes, which is what this helper method looks for.
   * results are cached per class.
   *
   * NOTE: enum and fixed classes generated by vanilla avro 1.4 do
   * not have a SCHEMA$ field and so will result in null.
   * may also return null for some classes generated by (vanilla) avro
   * that is a different version to the current runtime version
   * (an error will be logged with details)
   * @param generatedClass a class possibly generated by some version of avro
   * @return the {@link org.apache.avro.Schema} specified directly on the
   * given class, if any. null if none or not a class generated by avro
   * or a class generated by a different "major" version of avro.
   * @throws IllegalArgumentException if generatedClass is null
   */
  public static Schema getDeclaredSchema(Class<?> generatedClass) {
    if (generatedClass == null) {
      throw new IllegalArgumentException("argument cannot be null");
    }
    Schema cached = DECLARED_SCHEMAS.get(generatedClass);
    //identity comparison on purpose - NONE is a marker and must not be confused with a real NULL schema
    if (cached == null || cached == NONE) {
      return null;
    }
    return cached;
  }

  /**
   * returns the avro schema as specified on an "avro object" (instance
   * of specific generated class or generic avro class)
   * @param avroObject an avro object of some sort
   * @return schema declared for the given object, null if none could be found
   * @throws IllegalArgumentException if avroObject is null
   */
  public static Schema getDeclaredSchema(Object avroObject) {
    if (avroObject == null) {
      throw new IllegalArgumentException("argument cannot be null");
    }
    if (avroObject instanceof Enum || avroObject instanceof SpecificFixed) {
      //(possibly) generated (specific) enums and fixed have static SCHEMA$
      //(unless generated by vanilla 1.4) so we punt to by-class
      //TODO - look for FixedSize annotation on fixeds generated by 1.4 and generate schema
      return getDeclaredSchema(avroObject.getClass());
    }
    if (avroObject instanceof GenericContainer) {
      //specific and generic records go here, as do GenericArrays.
      //also generic enums and fixeds under some versions of avro (1.5+)
      return ((GenericContainer) avroObject).getSchema();
    }
    //generic enums and fixeds under avro 1.4 have no schema :-(
    //and neither would any other class that gets here
    return null;
  }

  /**
   * given a value and a union schema, determine which union branch the given value is an instance of
   * @param value some value, possibly null
   * @param union union schema (of the field the value came from)
   * @return member schema of the union that the value is an instance of
   * @throws IllegalArgumentException if union is null or not a union schema
   * @throws IllegalStateException if the avro type of value cannot be determined, or if
   * no branch of the union matches it
   */
  public static Schema resolveUnionBranchOf(Object value, Schema union) {
    if (union == null || union.getType() != Schema.Type.UNION) {
      throw new IllegalArgumentException("union schema is required (and should be a union)");
    }

    Schema.Type valueType = null;
    String valueFullname = null; //will be set if value is of named type

    if (value == null) {
      valueType = Schema.Type.NULL;
    } else if (value instanceof Boolean) {
      valueType = Schema.Type.BOOLEAN;
    } else if (value instanceof Integer) {
      valueType = Schema.Type.INT;
    } else if (value instanceof Long) {
      valueType = Schema.Type.LONG;
    } else if (value instanceof Float) {
      valueType = Schema.Type.FLOAT;
    } else if (value instanceof Double) {
      valueType = Schema.Type.DOUBLE;
    } else if (value instanceof CharSequence) {
      valueType = Schema.Type.STRING;
    } else if (value instanceof ByteBuffer) {
      valueType = Schema.Type.BYTES;
    } else if (value instanceof List) {
      valueType = Schema.Type.ARRAY;
    } else if (value instanceof Map) {
      valueType = Schema.Type.MAP;
    } else {
      //maybe a named type
      Schema valueDeclaredSchema = getDeclaredSchema(value);
      if (valueDeclaredSchema != null) {
        valueType = valueDeclaredSchema.getType();
        if (HelperConsts.NAMED_TYPES.contains(valueType)) {
          //we expect value to match union schema EXACTLY so wont take aliases into account
          valueFullname = valueDeclaredSchema.getFullName();
        }
      }
    }

    if (valueType == null) {
      //value cannot be null here (null values were assigned type NULL above)
      throw new IllegalStateException("unable to determine avro type for " + value.getClass().getName() + " " + value);
    }

    for (Schema candidate : union.getTypes()) {
      if (candidate.getType() != valueType) {
        continue;
      }
      //named types in unions must also match on fullname, all other types are unique in a union
      if (valueFullname != null && !valueFullname.equals(candidate.getFullName())) {
        continue;
      }
      return candidate;
    }

    StringBuilder sb = new StringBuilder();
    sb.append("unable to find ").append(valueType);
    if (valueFullname != null) {
      sb.append(" ").append(valueFullname);
    }
    sb.append(" in union ").append(union);
    throw new IllegalStateException(sb.toString());
  }

  /**
   * checks if the value for a given schema can possibly contain
   * strings (meaning is a string, union containing string, or collections
   * containing any of the above).
   * this is important when dealing with things like Utf8 vs java.lang.Strings.
   * note that records are not descended into - a record schema results in false
   * regardless of its fields.
   * @param schema a schema, possibly null (in which case the result is false)
   * @return true if value under schema could possibly involve strings
   */
  public static boolean schemaContainsString(Schema schema) {
    if (schema == null) {
      return false;
    }
    switch (schema.getType()) {
      case STRING:
      case MAP: //map keys are always strings, regardless of values
        return true;
      case UNION:
        //enough for any single branch to possibly involve strings
        for (Schema branch : schema.getTypes()) {
          if (schemaContainsString(branch)) {
            return true;
          }
        }
        return false;
      case ARRAY:
        return schemaContainsString(schema.getElementType());
      default:
        return false;
    }
  }

  /**
   * recursive implementation of schema traversal. visits the given schema (unless already
   * visited) and then descends into its union branches / element type / value type / fields.
   * for records, visitField() is invoked for every field just before descending into the field's schema.
   * @param schema schema to visit
   * @param visitor visitor to invoke
   * @param visited identity-based set (as a map) of schemas already visited during this traversal
   */
  private static void traverseSchema(Schema schema, SchemaVisitor visitor, IdentityHashMap<Object, Boolean> visited) {
    if (visited.put(schema, Boolean.TRUE) != null) {
      return; //been there, done that
    }
    visitor.visitSchema(schema);
    switch (schema.getType()) {
      case UNION:
        for (Schema unionBranch : schema.getTypes()) {
          traverseSchema(unionBranch, visitor, visited);
        }
        break;
      case ARRAY:
        traverseSchema(schema.getElementType(), visitor, visited);
        break;
      case MAP:
        traverseSchema(schema.getValueType(), visitor, visited);
        break;
      case RECORD:
        for (Schema.Field field : schema.getFields()) {
          visitor.visitField(schema, field);
          traverseSchema(field.schema(), visitor, visited);
        }
        break;
      default:
        //primitives, enums and fixeds have nothing further to descend into
        break;
    }
  }

  /**
   * @param value some datum
   * @return true if datum is specific record/fixed/enum or a collection/map thereof,
   * false if datum is any form of GenericData, null if unable to tell (primitives/wrappers,
   * or collections/maps containing nothing but such values)
   */
  static Boolean isSpecific(Object value) { //package access for testing
    if (value == null) {
      return null;
    }
    if (value instanceof SpecificRecordBase || value instanceof SpecificFixed || value instanceof Enum) {
      return true; //possibly some form of generated class
    }
    if (value instanceof GenericRecord || value instanceof GenericFixed || value instanceof GenericEnumSymbol) {
      //we already checked for specifics, so this means generics
      return false;
    }
    if (value instanceof Collection) {
      //1st element we can classify decides for the entire collection
      for (Object content : ((Collection<?>) value)) {
        Boolean result = isSpecific(content);
        if (result != null) {
          return result;
        }
      }
    }
    if (value instanceof Map) {
      //avro map keys are strings, so not interesting. 1st value we can classify decides
      for (Object content : ((Map<?, ?>) value).values()) {
        Boolean result = isSpecific(content);
        if (result != null) {
          return result;
        }
      }
    }
    return null; //dont know
  }
}
